---
title: Time Alignment with micro-tcn
keywords: fastai
sidebar: home_sidebar
nb_path: "02_time_align.ipynb"
---
Work in progress for NASH Hackathon, Dec 17, 2021
This is like the 01_td_demo notebook, only we use a different dataset and generalize the dataloader a bit.
%pip install -Uqq pip
# Next line only executes on Colab. Colab users: Please enable GPU in Edit > Notebook settings
! [ -e /content ] && pip install -Uqq fastai git+https://github.com/drscotthawley/fastproaudio.git
# Additional installs for this tutorial
%pip install -q fastai_minima torchsummary pyzenodo3 wandb
# Install micro-tcn and auraloss packages (from source, will take a little while)
%pip install -q wheel --ignore-requires-python git+https://github.com/csteinmetz1/micro-tcn.git git+https://github.com/csteinmetz1/auraloss
# After this cell finishes, restart the kernel and continue below
from fastai.vision.all import *
from fastai.text.all import *
from fastai.callback.fp16 import *
import wandb
from fastai.callback.wandb import *
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from IPython.display import Audio
import matplotlib.pyplot as plt
import torchsummary
from fastproaudio.core import *
from pathlib import Path
from glob import glob
import json
import re
We'll use Marco's guitar strike dataset (which is really from IDMT but whatever):
path_audiomdpi = get_audio_data(URLs.MARCO)
horn = path_audiomdpi / "LeslieHorn"; horn.ls()
path_dry = horn /'dry'
audio_extensions = ['.m3u', '.ram', '.au', '.snd', '.mp3','.wav']
fnames_dry = get_files(path_dry, extensions=audio_extensions)
Let's just take a look at one guitar pluck:
waveform, sample_rate = torchaudio.load(fnames_dry[0])
show_audio(waveform, sample_rate)
And now the main strategy: pasting in this one sample (IRL we'll use lots of them) along a track.
sample = waveform[0].numpy() # just simplify array dimensions for this demo
sample = sample[int(0.63*sample_rate):] # chop off the silence at the front for this demo
track_length = sample_rate*5
sample_len = sample.shape[-1]
target = np.zeros(track_length)
input = np.zeros(track_length)
click = np.zeros(track_length)
grid_interval = sample_rate
n_intervals = track_length // grid_interval
for i in range(n_intervals):  # paste samples at regular intervals
    start = grid_interval*i
    click[start] = 1  # click track
    end = min(start+sample_len, track_length)
    target[start:end] = sample[0:end-start]  # paste the sample on the grid
    # perturb the paste location by some amount
    rand_start = max(0, start + np.random.randint(-grid_interval//2, grid_interval//2))
    rand_end = min(rand_start+sample_len, track_length)
    input[rand_start:rand_end] = sample[0:rand_end-rand_start]
There's a click track that will be regarded as part of the multichannel input:
fig = plt.figure(figsize=(14, 2))
plt.plot(click)
Input is randomly perturbed from the grid:
fig = plt.figure(figsize=(14, 2))
plt.plot(input)
Target is on the grid:
fig = plt.figure(figsize=(14, 2))
plt.plot(target) # target is on the grid
The job of the network is: given input and click track, produce the target.
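As a rough sketch of what the assembled multichannel input could look like (the two-channel stacking here is illustrative; the actual channel ordering is pinned down later when we pack the WAV files):

```python
import numpy as np

sample_rate = 44100
track_length = sample_rate * 5

# stand-in mono tracks; in the notebook these come from the pasting loop above
click = np.zeros(track_length)
click[::sample_rate] = 1.0               # one click per grid interval
perturbed_input = np.zeros(track_length)  # off-grid audio (network input)
grid_target = np.zeros(track_length)      # on-grid audio (network target)

# stack click + audio into a (channels, samples) array, the shape a
# multichannel model input would take
net_input = np.stack([click, perturbed_input])
assert net_input.shape == (2, track_length)
```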
...This was all in mono and with only one audio sample, but IRL we'll have multiple channels of audio and a variety of audio samples. Speaking of:
Since the SignalTrain naming scheme and micro-tcn assume the existence of at least one "knob" (or conditioning parameter), we will use BPM values for the knob parameter. Those values will be surrounded by double underscores.
The dataset(s) will exist in two forms:

1. A _mono directory, containing Train/, Val/, and Test/ subdirectories as per the SignalTrain dataset spec. For example:

input_235-0_.wav
input_235-1_.wav
input_235-2_.wav
target_235-0__120__.wav
target_235-1__120__.wav
target_235-2__120__.wav

where "235" is in this case the common designation of the example number, -0 is the click track, -1 and -2 are audio channels, and __120__ is the BPM "knob" value. Technically the 120 should go on the input filenames too, since the click is part of the input, but the SignalTrain dataset filename scheme didn't work that way, and making that change would require re-doing the data reader, which we'd like to avoid. (Yes, it's redundant to include the click twice, but see the note below.)

2. A _mc directory, in which the mono tracks have been packed into multichannel WAV files. For those, the click track will be added to both the input and the target, and it will be channel 0. Same Train, Val, Test subdirs.

path = Path('wherever jacob puts the data')
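As a quick sanity check of that naming scheme, the example number, channel tag, and BPM "knob" can be pulled from a target filename with a small regex. (`parse_target_fname` is a hypothetical helper written against the filenames shown above, not part of micro-tcn.)

```python
import re

def parse_target_fname(fname):
    """Split e.g. 'target_235-1__120__.wav' into (example_id, channel_tag, bpm)."""
    m = re.match(r'target_(\d+)-([A-Za-z0-9]+)__([\d.]+)__\.wav$', fname)
    if m is None:
        raise ValueError(f"unexpected filename: {fname}")
    return m.group(1), m.group(2), float(m.group(3))

print(parse_target_fname('target_235-1__120__.wav'))  # ('235', '1', 120.0)
```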
fnames_in = sorted(glob(str(path)+'/*/input*'))
fnames_targ = sorted(glob(str(path)+'/*/*targ*'))
ind = -1 # pick one spot in the list of files
fnames_in[ind], fnames_targ[ind]
Input audio
waveform, sample_rate = torchaudio.load(fnames_in[ind])
show_audio(waveform, sample_rate)
Target output audio
target, sr_targ = torchaudio.load(fnames_targ[ind])
show_audio(target, sr_targ)
Let's look at the difference.
Difference
show_audio(target - waveform, sample_rate)
def get_accompanying_tracks(fn, fn_list, remove=False):
    """Given one filename and a list of all filenames, return a list of that filename and
    any files it 'goes with'.
    remove: remove these accompanying files from the main list.
    """
    # make copies of fn & fn_list with all hyphen+stuff removed
    basename = re.sub(r'-[a-zA-Z0-9]+', '', fn)
    basename_list = [re.sub(r'-[a-zA-Z0-9]+', '', x) for x in fn_list]
    # get indices of all elements of basename_list matching basename, return original filenames
    accompanying = [fn_list[i] for i, x in enumerate(basename_list) if x == basename]
    if remove:
        for x in accompanying:
            if x != fn: fn_list.remove(x)  # don't remove the file we searched on though
    return accompanying  # note: the accompanying list includes the original file too
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
print(fn_list)
track = fn_list[1]
print("getting matching tracks for ",track)
tracks = get_accompanying_tracks(fn_list[1], fn_list, remove=True)
print("Accompanying tracks are: ",tracks)
print("new list = ",fn_list) # should have the extra 21- tracks removed.
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
fn_list_save = fn_list.copy()
for x in fn_list:
    get_accompanying_tracks(x, fn_list, remove=True)
fn_list, fn_list_save
Here is the original dataset class that Christian made, subclassed so that we "pack" params and inputs together. This will load multichannel WAV files.
from microtcn.data import SignalTrainLA2ADataset
class SignalTrainLA2ADataset_fastai(SignalTrainLA2ADataset):
    "For fastai's sake, have __getitem__ pack the inputs and params together"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def __getitem__(self, idx):
        input, target, params = super().__getitem__(idx)
        return torch.cat((input, params), dim=-1), target  # pack input and params together
Dataset for loading multiple mono files and packing them together as multichannel:
'''
class MonoToMCDataset(torch.utils.data.Dataset):
    """
    UPDATE: turns out we're going to stick to Christian's original dataloader class and just use
    conversion scripts to pack or unpack mono WAV files into multichannel WAV files.
    ----
    Modifying Steinmetz's micro-tcn code so we can load the kind of multichannel audio we want.
    The difference is that now, we group files that are similar except for a hyphen-designation,
    e.g. input_235-1_.wav, input_235-2_.wav get read into one tensor.
    The 'trick' is that we only ever store one filename 'version' of a group of files, but whenever we
    want to load that file, we also grab all its associated files.
    Like the SignalTrain LA2A dataset, only more general."""

    def __init__(self, root_dir, subset="train", length=16384, preload=False, half=True, fraction=1.0, use_soundfile=False):
        """
        Args:
            root_dir (str): Path to the root directory of the SignalTrain dataset.
            subset (str, optional): Pull data either from "train", "val", "test", or "full" subsets. (Default: "train")
            length (int, optional): Number of samples in the returned examples. (Default: 16384)
            preload (bool, optional): Read in all data into RAM during init. (Default: False)
            half (bool, optional): Store the float32 audio as float16. (Default: True)
            fraction (float, optional): Fraction of the data to load from the subset. (Default: 1.0)
            use_soundfile (bool, optional): Use the soundfile library to load instead of torchaudio. (Default: False)
        """
        self.root_dir = root_dir
        self.subset = subset
        self.length = length
        self.preload = preload
        self.half = half
        self.fraction = fraction
        self.use_soundfile = use_soundfile

        if self.subset == "full":
            self.target_files = glob.glob(os.path.join(self.root_dir, "**", "target_*.wav"))
            self.input_files = glob.glob(os.path.join(self.root_dir, "**", "input_*.wav"))
        else:
            # get all the target files in the directory first
            self.target_files = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "target_*.wav"))
            self.input_files = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "input_*.wav"))

        self.examples = []
        self.minutes = 0  # total number of minutes in the subset

        # ensure that the sets are ordered correctly
        self.target_files.sort()
        self.input_files.sort()

        # get the parameters
        self.params = [(float(f.split("__")[1].replace(".wav","")), float(f.split("__")[2].replace(".wav",""))) for f in self.target_files]

        # SHH: HERE is where we package similar hyphen-designated files together. (A list comprehension wouldn't work here, btw.)
        # Essentially we are removing 'duplicates': the first file of each group will be the signifier for all of them.
        self.target_files_all, self.input_files_all = self.target_files.copy(), self.input_files.copy()  # save copies of the original lists
        for x in self.target_files:  # remove extra accompanying tracks from the main list that the loader will use
            get_accompanying_tracks(x, self.target_files, remove=True)
        for x in self.input_files:
            get_accompanying_tracks(x, self.input_files, remove=True)
        # make a dict that maps each main file name to its list of accompanying files (including itself)
        self.target_accomp = {f: get_accompanying_tracks(f, self.target_files_all) for f in self.target_files}
        self.input_accomp = {f: get_accompanying_tracks(f, self.input_files_all) for f in self.input_files}

        # loop over files to count total length
        for idx, (tfile, ifile, params) in enumerate(zip(self.target_files, self.input_files, self.params)):
            ifile_id = int(os.path.basename(ifile).split("_")[1])
            tfile_id = int(os.path.basename(tfile).split("_")[1])
            if ifile_id != tfile_id:
                raise RuntimeError(f"Found non-matching file ids: {ifile_id} != {tfile_id}! Check dataset.")

            md = torchaudio.info(tfile)
            num_frames = md.num_frames

            if self.preload:
                sys.stdout.write(f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...\r")
                sys.stdout.flush()
                input, sr = self.load_accompanying(ifile, self.input_accomp)
                target, sr = self.load_accompanying(tfile, self.target_accomp)

                num_frames = int(np.min([input.shape[-1], target.shape[-1]]))
                if input.shape[-1] != target.shape[-1]:
                    print(os.path.basename(ifile), input.shape[-1], os.path.basename(tfile), target.shape[-1])
                    raise RuntimeError("Found potentially corrupt file!")
                if self.half:
                    input = input.half()
                    target = target.half()
            else:
                input = None
                target = None

            # create one entry for each patch
            self.file_examples = []
            for n in range((num_frames // self.length)):
                offset = int(n * self.length)
                end = offset + self.length
                self.file_examples.append({"idx": idx,
                                           "target_file": tfile,
                                           "input_file": ifile,
                                           "input_audio": input[:,offset:end] if input is not None else None,
                                           "target_audio": target[:,offset:end] if target is not None else None,
                                           "params": params,
                                           "offset": offset,
                                           "frames": num_frames})

            # add to overall file examples
            self.examples += self.file_examples

        # use only a fraction of the subset data if applicable
        if self.subset == "train":
            classes = set([ex['params'] for ex in self.examples])
            n_classes = len(classes)  # number of unique compressor configurations
            fraction_examples = int(len(self.examples) * self.fraction)
            n_examples_per_class = int(fraction_examples / n_classes)
            n_min_total = ((self.length * n_examples_per_class * n_classes) / md.sample_rate) / 60
            n_min_per_class = ((self.length * n_examples_per_class) / md.sample_rate) / 60

            print(sorted(classes))
            print(f"Total Examples: {len(self.examples)}   Total classes: {n_classes}")
            print(f"Fraction examples: {fraction_examples}   Examples/class: {n_examples_per_class}")
            print(f"Training with {n_min_per_class:0.2f} min per class   Total of {n_min_total:0.2f} min")

            if n_examples_per_class <= 0:
                raise ValueError(f"Fraction `{self.fraction}` set too low. No examples selected.")

            sampled_examples = []
            for config_class in classes:  # select N examples from each class
                class_examples = [ex for ex in self.examples if ex["params"] == config_class]
                example_indices = np.random.randint(0, high=len(class_examples), size=n_examples_per_class)
                class_examples = [class_examples[idx] for idx in example_indices]
                extra_factor = int(1/self.fraction)
                sampled_examples += class_examples * extra_factor

            self.examples = sampled_examples

        self.minutes = ((self.length * len(self.examples)) / md.sample_rate) / 60

        # we then want to get the input files
        print(f"Located {len(self.examples)} examples totaling {self.minutes:0.2f} min in the {self.subset} subset.")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        if self.preload:
            audio_idx = self.examples[idx]["idx"]
            offset = self.examples[idx]["offset"]
            input = self.examples[idx]["input_audio"]
            target = self.examples[idx]["target_audio"]
        else:
            offset = self.examples[idx]["offset"]
            input_name = self.examples[idx]["input_file"]
            target_name = self.examples[idx]["target_file"]
            input = torch.empty((len(self.input_accomp[input_name]), self.length))
            for c, fname in enumerate(self.input_accomp[input_name]):
                input[c], sr = torchaudio.load(fname,
                                               num_frames=self.length,
                                               frame_offset=offset,
                                               normalize=False)
            target = torch.empty((len(self.target_accomp[target_name]), self.length))
            for c, fname in enumerate(self.target_accomp[target_name]):
                target[c], sr = torchaudio.load(fname,
                                                num_frames=self.length,
                                                frame_offset=offset,
                                                normalize=False)

        if self.half:
            input = input.half()
            target = target.half()

        # at random with p=0.5, flip the phase
        if np.random.rand() > 0.5:
            input *= -1
            target *= -1

        # then get the tuple of parameters
        params = torch.tensor(self.examples[idx]["params"]).unsqueeze(0)
        params[:,1] /= 100

        return input, target, params

    def load(self, filename):
        if self.use_soundfile:
            x, sr = sf.read(filename, always_2d=True)
            x = torch.tensor(x.T)
        else:
            x, sr = torchaudio.load(filename, normalize=False)
        return x, sr

    def load_accompanying(self, filename, accomp_dict):
        accomp = accomp_dict[filename]
        self.num_channels = len(accomp)
        md = torchaudio.info(filename)  # TODO: fix: assumes all accompanying tracks are the same shape, etc!
        num_frames = md.num_frames
        data = torch.empty((self.num_channels, num_frames))
        for c, afile in enumerate(accomp):
            data[c], sr = self.load(afile)
        return data, sr
'''
class Args(object):  # stand-in for argparse; these are all micro-tcn defaults
    model_type = 'tcn'
    root_dir = str(path)
    preload = False
    sample_rate = 44100
    shuffle = True
    train_subset = 'train'
    val_subset = 'val'
    train_length = 65536
    train_fraction = 1.0
    eval_length = 131072
    batch_size = 8   # original is 32; my laptop needs smaller, esp. w/o half precision
    num_workers = 4
    precision = 32   # LEAVE AS 32 FOR NOW: HALF PRECISION (16) NOT WORKING YET -SHH
    n_params = 2

args = Args()
#if args.precision == 16: torch.set_default_dtype(torch.float16)
# setup the dataloaders
train_dataset = SignalTrainLA2ADataset_fastai(args.root_dir,
                                              subset=args.train_subset,
                                              fraction=args.train_fraction,
                                              half=True if args.precision == 16 else False,
                                              preload=args.preload,
                                              length=args.train_length)

train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=args.shuffle,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

val_dataset = SignalTrainLA2ADataset_fastai(args.root_dir,
                                            preload=args.preload,
                                            half=True if args.precision == 16 else False,
                                            subset=args.val_subset,
                                            length=args.eval_length)

val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=False,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
If the user requested fp16 precision then we need to install NVIDIA apex:
if False and args.precision == 16:
    %pip install -q --disable-pip-version-check --no-cache-dir git+https://github.com/NVIDIA/apex
    from apex.fp16_utils import convert_network
from microtcn.tcn_bare import TCNModel as TCNModel
#from microtcn.lstm import LSTMModel # actually the LSTM depends on a lot of Lightning stuff, so we'll skip that
from microtcn.utils import center_crop, causal_crop
class TCNModel_fastai(TCNModel):
    "For fastai's sake, unpack the inputs and params"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, x, p=None):
        if (p is None) and (self.nparams > 0):  # unpack the params if needed
            assert len(list(x.size())) == 3  # sanity check
            x, p = x[:,:,0:-self.nparams], x[:,:,-self.nparams:]
        return super().forward(x, p=p)
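The pack/unpack round trip used by `SignalTrainLA2ADataset_fastai` and `TCNModel_fastai` can be sanity-checked in isolation. This is a minimal sketch with numpy stand-ins for the tensors (same last-axis concatenation, same slicing; the shapes follow the (batch, channels, samples) convention above):

```python
import numpy as np

batch, chans, samples, nparams = 8, 1, 16384, 2
x = np.random.randn(batch, chans, samples)   # audio
p = np.random.randn(batch, chans, nparams)   # conditioning params

packed = np.concatenate([x, p], axis=-1)     # what __getitem__ returns
# what forward() splits back off:
x2, p2 = packed[:, :, :-nparams], packed[:, :, -nparams:]
assert np.array_equal(x, x2) and np.array_equal(p, p2)
```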
# micro-tcn defines several different model configurations. I just chose one of them.
train_configs = [
    {"name": "TCN-300",
     "model_type": "tcn",
     "nblocks": 10,
     "dilation_growth": 2,
     "kernel_size": 15,
     "causal": False,
     "train_fraction": 1.00,
     "batch_size": args.batch_size
    }
]
dict_args = train_configs[0]
dict_args["nparams"] = 2
model = TCNModel_fastai(**dict_args)
dtype = torch.float32
Let's take a look at the model:
# this summary allows one to compare the original TCNModel with the TCNModel_fastai
if type(model) == TCNModel_fastai:
    torchsummary.summary(model, [(1, args.train_length)], device="cpu")
else:
    torchsummary.summary(model, [(1, args.train_length), (1, 2)], device="cpu")
Zach Mueller made a very helpful fastai_minima package that we'll use, and follow his instructions.
TODO: Zach says I should use either fastai or fastai_minima, not mix them like I'm about to do. But what I have below is the only thing that works right now. ;-)
# I guess we could've imported these up at the top of the notebook...
from torch import optim
from fastai_minima.optimizer import OptimWrapper
#from fastai_minima.learner import Learner # this doesn't include lr_find()
from fastai.learner import Learner
from fastai_minima.learner import DataLoaders
#from fastai_minima.callback.training_utils import CudaCallback, ProgressCallback # not sure if I need these
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, **kwargs))
dls = DataLoaders(train_dataloader, val_dataloader)
if args.precision == 16:
    dtype = torch.float16
    model = convert_network(model, torch.float16)

model = model.to('cuda:0')
if type(model) == TCNModel_fastai:
    print("We're using Hawley's modified code")
    packed, targ = dls.one_batch()
    inp, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
else:
    print("We're using Christian's version of the Dataloader and model")
    inp, targ, params = dls.one_batch()
    pred = model.forward(inp.to('cuda:0', dtype=dtype), p=params.to('cuda:0', dtype=dtype))

print(f"input = {inp.size()}\ntarget = {targ.size()}\nparams = {params.size()}\npred = {pred.size()}")
We can make the pred and target the same length by cropping when we compute the loss:
class Crop_Loss:
    "Crop target size to match preds"
    def __init__(self, axis=-1, causal=False, reduction="mean", func=nn.L1Loss):
        store_attr()
        self.loss_func = func()
    def __call__(self, pred, targ):
        targ = causal_crop(targ, pred.shape[-1]) if self.causal else center_crop(targ, pred.shape[-1])
        #pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, f'pred.shape = {pred.shape} but targ.shape = {targ.shape}'
        return self.loss_func(pred, targ).flatten().mean() if self.reduction == "mean" else self.loss_func(pred, targ).flatten().sum()
# we could add a metric like MSE if we want
def crop_mse(pred, targ, causal=False):
    targ = causal_crop(targ, pred.shape[-1]) if causal else center_crop(targ, pred.shape[-1])
    return ((pred - targ)**2).mean()
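For reference, here's a sketch of what the two crop modes do: center cropping keeps the middle `length` samples, causal cropping keeps the last `length`. These numpy versions are illustrations consistent with that behavior, not micro-tcn's actual implementations:

```python
import numpy as np

def center_crop_np(x, length):
    # keep the middle `length` samples along the last axis
    start = (x.shape[-1] - length) // 2
    return x[..., start:start + length]

def causal_crop_np(x, length):
    # keep the last `length` samples along the last axis
    return x[..., -length:]

targ = np.arange(10)
print(center_crop_np(targ, 4))  # [3 4 5 6]
print(causal_crop_np(targ, 4))  # [6 7 8 9]
```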
wandb.login()
class WandBAudio(Callback):
    """Progress-like callback: log audio to WandB"""
    order = ProgressCallback.order + 1
    def __init__(self, n_preds=5, sample_rate=44100):
        store_attr()

    def after_epoch(self):
        if not self.learn.training:
            with torch.no_grad():
                preds, targs = [x.detach().cpu().numpy().copy() for x in [self.learn.pred, self.learn.y]]
            log_dict = {}
            for i in range(min(self.n_preds, preds.shape[0])):  # note wandb only supports mono
                log_dict[f"preds_{i}"] = wandb.Audio(preds[i,0,:], caption=f"preds_{i}", sample_rate=self.sample_rate)
            wandb.log(log_dict)
wandb.init(project='micro-tcn-fastai')  # no name given; could use name=json.dumps(dict_args)
learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func,
                cbs=[WandbCallback()])
We can use the fastai learning rate finder to suggest a learning rate:
learn.lr_find(end_lr=0.1)
And now we'll train using the one-cycle LR schedule, with the WandBAudio callback. (Ignore any warning messages)
epochs = 20 # change to 50 for better results but a longer wait
learn.fit_one_cycle(epochs, lr_max=3e-3, cbs=WandBAudio(sample_rate=args.sample_rate))
wandb.finish() # call wandb.finish() after training or your logs may be incomplete
learn.save('micro-tcn-fastai')
Go check out the resulting run logs, graphs, and audio samples at https://wandb.ai/drscotthawley/micro-tcn-fastai, or... lemme see if I can embed some results below:
...ok it looks like the WandB results iframe (with cool graphs & audio) is getting filtered out of the docs (by nbdev and/or jekyll), but if you open this notebook file -- e.g. click the "Open in Colab" badge at the top -- then scroll down and you'll see the report. Or just go to the WandB link posted above!
test_dataset = SignalTrainLA2ADataset_fastai(args.root_dir,
                                             preload=args.preload,
                                             half=True if args.precision == 16 else False,
                                             subset='test',
                                             length=args.eval_length)

test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=False,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)
learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func, cbs=[])
learn.load('micro-tcn-fastai')
^^ 9 examples? I thought there were only 3:
!ls {path}/Test
...Ok I don't understand that yet. Moving on:
Let's get some predictions from the model. Note that the length of these predictions will be greater than in training, because we specified the lengths differently:
print(args.train_length, args.eval_length)
Handy routine to grab some data and run it through the model to get predictions:
def get_pred_batch(dataloader, crop_target=True, causal=False):
    packed, target = next(iter(dataloader))
    input, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
    if crop_target: target = causal_crop(target, pred.shape[-1]) if causal else center_crop(target, pred.shape[-1])
    input, params, target, pred = [x.detach().cpu() for x in [input, params, target, pred]]
    return input, params, target, pred
input, params, target, pred = get_pred_batch(test_dataloader, causal=dict_args['causal'])
i = 0 # just look at the first element
print(f"------- i = {i} ---------\n")
print(f"prediction:")
show_audio(pred[i], sample_rate)
print(f"target:")
show_audio(target[i], sample_rate)
TODO: More. We're not finished. I'll come back and add more to this later.
Check out Christian's GitHub page for micro-tcn where he provides instructions and JUCE files by which to render the model as an audio plugin. Pretty sure you can only do this with the causal models, which I didn't include -- yet!